// Copyright © 2025 Apple Inc.
// Copyright © 2008-2013 NVIDIA Corporation
// Copyright © 2013 Filipe RNC Maia
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Forked from
// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h

// TODO: We should use thrust::exp, but the thrust header in old CUDA versions
// cannot be used in JIT.

#pragma once

#include <metal_math>

// Bit-level view of a 32-bit float: store through one member and read the
// other to reinterpret an IEEE-754 single-precision value as its raw 32-bit
// word (or the reverse).
union ieee_float_shape_type {
  float value;
  uint32_t word;
};

// Reads the raw IEEE-754 bit pattern of `d` into the unsigned word `i`.
inline void get_float_word(thread uint32_t& i, float d) {
  ieee_float_shape_type bits;
  bits.value = d;
  i = bits.word;
}

// Reads the raw IEEE-754 bit pattern of `d` into the signed integer `i`.
// The unsigned word is converted to int32_t. On two's-complement targets, a
// set sign bit in `d` therefore produces a negative `i`.
inline void get_float_word(thread int32_t& i, float d) {
  ieee_float_shape_type bits;
  bits.value = d;
  i = static_cast<int32_t>(bits.word);
}

// Overwrites `d` with the float whose IEEE-754 bit pattern is `i`.
inline void set_float_word(thread float& d, uint32_t i) {
  ieee_float_shape_type bits;
  bits.word = i;
  d = bits.value;
}

// Splits exp(x) into a mantissa and a power-of-two exponent:
//   exp(x) ~= result * 2^(*expt), with result in [2^127, 2^128).
//
// Intended for large x, where expf(x) itself would overflow (cexpf routes
// x roughly in [88.7, 192] here). The exponential is evaluated at the
// shifted point x - k*ln(2), and the float exponent field of that result is
// then replaced. The identity only holds while exp(x - k*ln(2)) is a normal,
// finite float.
inline float frexp_expf(float x, thread int* expt) {
  // Reduction constant and k * ln(2), rounded to float.
  const uint32_t k = 235;
  const float kln2 = 162.88958740F;
  // Biased exponent 254, which places the returned mantissa in [2^127, 2^128).
  const uint32_t scaled_bias = 0x7f + 127;

  uint32_t bits;
  get_float_word(bits, metal::exp(x - kln2));

  // Unbiased exponent of the shifted exp, re-expressed relative to
  // scaled_bias, plus the k powers of two removed by the shift. The
  // arithmetic is unsigned, exactly as in the forked implementation.
  *expt = (bits >> 23) - scaled_bias + k;

  // Keep the 23 fraction bits and force the exponent field to scaled_bias.
  float mantissa;
  set_float_word(mantissa, (bits & 0x7fffff) | (scaled_bias << 23));
  return mantissa;
}

// Computes exp(z) * 2^expt for z = x + i y without overflowing in an
// intermediate exp(x).
//
// frexp_expf yields exp(x) as a mantissa times a power of two. The combined
// exponent is then applied as two separate power-of-two factors, so each
// factor's biased exponent can stay in range even when the total exponent
// exceeds 127. cexpf calls this with expt == 0.
inline complex64_t ldexp_cexpf(complex64_t z, int expt) {
  int reduced_expt;
  const float mantissa = frexp_expf(z.real, &reduced_expt);
  const int total_expt = expt + reduced_expt;

  // 2^total_expt == 2^lo_expt * 2^hi_expt. (0x7f + e) << 23 is the bit
  // pattern of the float 2^e while 0x7f + e is a normal biased exponent.
  const int lo_expt = total_expt / 2;
  const int hi_expt = total_expt - lo_expt;
  float scale1;
  float scale2;
  set_float_word(scale1, (0x7f + lo_expt) << 23);
  set_float_word(scale2, (0x7f + hi_expt) << 23);

  // Evaluation is left to right. The trig value is multiplied by the
  // mantissa first, so that product is never larger than the mantissa; only
  // then are the scales applied.
  const float y = z.imag;
  return complex64_t{
      metal::cos(y) * mantissa * scale1 * scale2,
      metal::sin(y) * mantissa * scale1 * scale2};
}

// Single-precision complex exponential:
//   exp(x + i y) = exp(x) * (cos(y) + i sin(y)).
//
// Special inputs are handled first: zero imaginary part, zero real part, and
// an infinite or NaN imaginary part. For positive x roughly in [88.7, 192],
// expf(x) alone would overflow even though the scaled result may not. That
// range is delegated to ldexp_cexpf.
inline complex64_t cexpf(const thread complex64_t& z) {
  // Bit patterns of the real-part thresholds.
  const uint32_t exp_ovfl = 0x42b17218;   // ~88.72, about ln(FLT_MAX)
  const uint32_t cexp_ovfl = 0x43400074;  // ~192.0

  const float x = z.real;
  const float y = z.imag;

  uint32_t hy;
  get_float_word(hy, y);
  hy &= 0x7fffffff;  // bits of |y|

  // cexp(x + I 0) = exp(x) + I 0. Returning y preserves the sign of zero.
  if (hy == 0) {
    return complex64_t{metal::exp(x), y};
  }

  uint32_t hx;
  get_float_word(hx, x);
  const uint32_t abs_hx = hx & 0x7fffffff;

  // cexp(0 + I y) = cos(y) + I sin(y)
  if (abs_hx == 0) {
    return complex64_t{metal::cos(y), metal::sin(y)};
  }

  // Imaginary part is +-Inf or NaN.
  if (hy >= 0x7f800000) {
    if (abs_hx != 0x7f800000) {
      // cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN
      // (under IEEE rules, Inf - Inf and NaN - NaN are NaN).
      return complex64_t{y - y, y - y};
    }
    if (hx & 0x80000000) {
      // cexp(-Inf +- I Inf|NaN) = 0 + I 0
      return complex64_t{0.0f, 0.0f};
    }
    // cexp(+Inf +- I Inf|NaN) = Inf + I NaN
    return complex64_t{x, y - y};
  }

  // hx is compared unmasked, so any x with the sign bit set fails this test
  // and takes the direct path below.
  const bool needs_scaling = (hx >= exp_ovfl) && (hx <= cexp_ovfl);
  if (needs_scaling) {
    // x is between 88.7 and 192: scale to avoid overflow in expf(x).
    return ldexp_cexpf(z, 0);
  }

  // Remaining cases:
  //  -  x < exp_ovfl and exp(x) won't overflow (common case)
  //  -  x > cexp_ovfl, so exp(x) * s overflows for all s > 0
  //  -  x = +-Inf (generated by exp())
  //  -  x = NaN (spurious inexact exception from y)
  const float exp_x = metal::exp(x);
  return complex64_t{exp_x * metal::cos(y), exp_x * metal::sin(y)};
}
